import copy

import numpy as np
import torch
from allennlp.data import Token, Instance
from allennlp.data.fields import TextField, LabelField
from allennlp.predictors import Predictor
import allennlp.nn.util as util

from .attacker import Attacker
from config import Config


class HotFlipAttacker(Attacker):
    """HotFlip-style greedy word-substitution attacker.

    At each step the token of ``'sentence'`` whose input-embedding gradient has
    the largest squared norm is replaced by the vocabulary word ``e'`` that
    maximises the first-order estimate ``(e' - e) . grad``, where ``grad`` is the
    gradient returned by ``predictor.get_gradients`` (the model loss for stock
    AllenNLP predictors).
    """

    def __init__(self, cf: Config, predictor: Predictor):
        super(HotFlipAttacker, self).__init__(cf, predictor)
        # Private AllenNLP attributes: Predictor has no public accessors for them.
        self.model = predictor._model
        self.reader = predictor._dataset_reader

        # Weight of the embedding layer located by ``util.find_embedding_layer``;
        # rows are indexed by vocabulary ids (assumed default 'tokens' namespace).
        self.embedding_matrix = util.find_embedding_layer(self.model).weight

    def forward(self, instance,
                grad_input_field: str = 'grad_input_1'):
        """Greedily flip tokens of ``instance['sentence']`` until ``self.stop`` fires.

        Args:
            instance: AllenNLP ``Instance`` with a ``'sentence'`` text field. It is
                never mutated; every flip is applied to a deep copy.
            grad_input_field: Key of the embedding gradient in the gradient dict
                produced by ``predictor.get_gradients``.

        Returns:
            The dict built by ``self.attack_result``, extended with ``'gold'``,
            ``'pred'`` and ``'adv_example'``. ``'success'`` and ``'length'`` tell
            whether ``self.stop`` fired and after how many flips;
            ``'substitutes'`` holds one ``[original_word, new_word, position,
            gold_prob_after_flip]`` entry per flip.
        """
        num_tokens = len(instance['sentence'])
        # Each position is flipped at most once, so never plan more flips than tokens.
        budget = min(self.attack_num(num_tokens), num_tokens)

        result = self.attack_result(success=False, length=0, substitudes=[])
        # The keyword above is spelled 'substitudes' but the list is read back as
        # 'substitutes'; make sure the key we append to actually exists.
        substitutes = result.setdefault('substitutes', [])

        adv_instance = instance
        flipped = []
        output = self.predictor.predict_instance(adv_instance)

        # budget + 1 iterations so that ``stop`` is also evaluated after the last flip.
        for step in range(budget + 1):
            # Presumably truthy once the attack goal is reached for this output.
            if self.stop([output])[0]:
                result['success'] = True
                result['length'] = step
                break
            if step == budget:
                # Budget exhausted: a further flip would never be checked by ``stop``.
                break

            gradients = self.predictor.get_gradients([adv_instance])
            if isinstance(gradients, tuple):
                # Stock AllenNLP returns (grad_dict, outputs); also accept a bare dict.
                gradients = gradients[0]
            # Stock AllenNLP yields numpy arrays; bring the gradient onto the
            # embedding's dtype/device so the torch math below works either way.
            # NOTE(review): assumes one gradient row per token of 'sentence'
            # (single-id indexer); wordpiece indexers would misalign positions.
            input_grad = torch.as_tensor(
                gradients[grad_input_field][0],
                dtype=self.embedding_matrix.dtype,
                device=self.embedding_matrix.device)

            grad_magnitude = (input_grad ** 2).sum(dim=-1)

            # Only flip a token once: squared norms are >= 0, so -1 excludes it.
            for pos in flipped:
                grad_magnitude[pos] = -1

            # We flip the token with highest gradient norm.
            flip_idx = int(grad_magnitude.argmax())
            flipped.append(flip_idx)

            # This position has not been flipped yet, so it still holds the original word.
            origin_word = adv_instance['sentence'].tokens[flip_idx].text
            origin_id = self.model.vocab.get_token_index(origin_word)

            # Pure scoring: no autograd graph against the embedding parameter.
            with torch.no_grad():
                direction = (self.embedding_matrix
                             - self.embedding_matrix[origin_id]) @ input_grad[flip_idx]
                # Replacing a word with itself would waste a flip.
                direction[origin_id] = float('-inf')
            candidate = self.model.vocab.get_token_from_index(int(direction.argmax()))

            adv_instance = self.subsitude(adv_instance, flip_idx, candidate)
            output = self.predictor.predict_instance(adv_instance)

            substitutes.append(
                [origin_word, candidate, flip_idx, output['gold_prob']])

        result['gold'] = output['gold']
        result['pred'] = output['pred']
        result['adv_example'] = self.reader.instance_to_text(adv_instance)
        return result

    def subsitude(self, instance, i, word):
        """Return a deep copy of ``instance`` with token ``i`` of ``'sentence'`` set to ``word``.

        The copy's ``indexed`` flag is cleared so that AllenNLP re-indexes the
        changed token before the next prediction. (The method name's spelling is
        kept for compatibility with existing callers.)
        """
        new_instance = copy.deepcopy(instance)
        new_instance['sentence'].tokens[i] = Token(word)
        new_instance.indexed = False
        return new_instance
